
Enable Qwen3.5-MoE PTQ #897

Merged

Edwardf0t1 merged 10 commits into main from zhiyu/qwen3p5-moe-support on Feb 27, 2026

Conversation


@Edwardf0t1 Edwardf0t1 commented Feb 16, 2026

What does this PR do?

Type of change: New model support

Overview: Add ModelOpt PTQ support for https://huggingface.co/Qwen/Qwen3.5-397B-A17B

Usage

python3 hf_ptq.py --pyt_ckpt_path /home/omniml_data_3/models/Qwen3.5-397B-A17B --qformat nvfp4_mlp_only --export_path /home/omniml_data_3/zhiyuc/checkpoints/Qwen3.5-397B-A17B-NVFP4 --trust_remote_code

Testing

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes

Additional Information

Summary by CodeRabbit

  • New Features

    • Added Qwen3.5 Mixture-of-Experts model support in quantization workflows.
  • Bug Fixes

    • Enhanced error diagnostics during model export with detailed module information.
    • Improved dataset tokenizer processing with proper truncation and length handling.
    • Fixed model export stability issue related to framework integration.

copy-pr-bot bot commented Feb 16, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.


coderabbitai bot commented Feb 16, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

The pull request adds support for Qwen3.5 MoE quantization across the export pipeline, improves error reporting during quantized weight export with module-aware diagnostics, introduces a patching mechanism to handle transformer-related weight conversion issues, enhances tokenizer encoding with truncation parameters, and relocates custom model file copying in PTQ workflows.

Changes

  • PTQ Example — examples/llm_ptq/hf_ptq.py: Relocates the copy_custom_model_files call to after tokenizer export so that original tokenizer files take precedence over regenerated ones, with expanded explanatory comments.
  • MOE Recognition Expansion — modelopt/torch/export/layer_utils.py: Adds "Qwen3_5MoeSparseMoeBlock" to the MOE detection logic in is_moe, module_match_name_list, and get_expert_linear_names. Improves amax handling in set_expert_quantizer_amax to treat zero-valued tensors as needing recalibration. Adjusts QKV splitting with a Pythonic truthiness check.
  • Export Error Handling & Patching — modelopt/torch/export/unified_export_hf.py: Enhances error diagnostics by capturing module names during quantized weight export and wrapping failures with detailed context. Introduces a reversible patching mechanism (_revert_weight_conversion_noop, _try_patch_module, _patch_revert_weight_conversion, _unpatch_revert_weight_conversion) to safely disable transformers' weight conversion during HF checkpoint export.
  • Qwen3.5 MoE Quantization — modelopt/torch/quantization/plugins/huggingface.py: Adds three new quantized module classes (_Qwen35MoeExpertModule, _QuantQwen35MoeExperts, _QuantQwen35MoeSparseMoeBlock) with expert-aware routing and per-expert module management. Registers the new modules in QuantModuleRegistry for Qwen3.5 MoE support. Includes guarded imports for optional HF integration.
  • Dataset Tokenization — modelopt/torch/utils/dataset_utils.py: Updates the tokenizer invocation from batch_encode_plus to a direct call with added truncation=True and max_length parameters. Minor comment clarification regarding in-place tokenizer modifications.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 66.67%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Title check — ✅ Passed: the title 'Enable Qwen3.5-MoE PTQ' directly and accurately summarizes the main objective of the PR: adding post-training quantization support for the Qwen3.5 MoE model.
  • Description Check — ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

codecov bot commented Feb 16, 2026

Codecov Report

❌ Patch coverage is 0% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 72.03%. Comparing base (a6cbcba) to head (c7bb291).
⚠️ Report is 1 commits behind head on main.

  • modelopt/torch/utils/dataset_utils.py — Patch: 0.00%, 1 line missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main     #897   +/-   ##
=======================================
  Coverage   72.03%   72.03%           
=======================================
  Files         207      207           
  Lines       22718    22718           
=======================================
  Hits        16365    16365           
  Misses       6353     6353           

☔ View full report in Codecov by Sentry.

@Edwardf0t1 Edwardf0t1 marked this pull request as ready for review February 25, 2026 02:02
@Edwardf0t1 Edwardf0t1 requested review from a team as code owners February 25, 2026 02:02
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
modelopt/torch/export/layer_utils.py (2)

1085-1102: ⚠️ Potential issue | 🟠 Major

Zero amax values are still collected into valid_amax_values, causing them to silently perpetuate when all experts are uncalibrated.

The new needs_amax condition (line 1142) correctly identifies all-zero amax tensors as invalid. However, the valid_amax_values collection loop (lines 1089–1098) only checks existing_amax is not None, so zero tensors are collected. When every expert has amax == 0:

  1. valid_amax_values = [0, 0, ...]
  2. target_amax = torch.max(stack([0, ...])) = 0
  3. elif target_amax is None branch (line 1105) is skipped — weight-stat fallback never runs
  4. needs_amax = True → quantizer.amax is set to 0 again
  5. Warning misleadingly says "Setting it to 0.000000 (max from existing quantizers in current batch)"

The fix is to mirror the needs_amax predicate in the collection loop:

🐛 Proposed fix
     valid_amax_values = []
     for _, attr_name, quantizer in all_quantizers:
         existing_amax = getattr(quantizer, "amax", None)
-        if existing_amax is not None:
+        if existing_amax is not None and not (
+            isinstance(existing_amax, torch.Tensor) and torch.all(existing_amax == 0)
+        ):
             # Convert to tensor and add to collection
             if isinstance(existing_amax, torch.Tensor):
                 valid_amax_values.append(existing_amax.to(target_device))
             else:
                 valid_amax_values.append(
                     torch.tensor(existing_amax, dtype=torch.float32, device=target_device)
                 )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/layer_utils.py` around lines 1085 - 1102, When
collecting existing amax values into valid_amax_values (loop over all_quantizers
/ existing_amax), only append values that are non-zero using the same predicate
as needs_amax: convert existing_amax to a torch.Tensor on target_device first
and check it's not all zeros (e.g., tensor.ne(0).any()); skip appending if the
tensor is all zeros or None so target_amax won't be set to 0 when all experts
are uncalibrated, allowing the weight-stat fallback to run and avoiding
misleading warnings for quantizer.amax.
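The predicate the reviewer wants mirrored can be sketched without torch, using plain Python values as stand-ins for amax tensors (the helper name `is_valid_amax` is illustrative, not the real ModelOpt API): None or all-zero amax means "never calibrated", so it must be skipped both when deciding a quantizer needs recalibration and when collecting donor values.

```python
# Toy version of the "skip zero amax" check: a scalar or list of floats
# stands in for a torch tensor.

def is_valid_amax(amax) -> bool:
    """None or all-zero amax values are treated as uncalibrated."""
    if amax is None:
        return False
    values = amax if isinstance(amax, list) else [amax]
    return any(v != 0 for v in values)

quantizer_amaxes = [None, [0.0, 0.0], [1.5, 2.0], 0.0]
valid = [a for a in quantizer_amaxes if is_valid_amax(a)]
print(valid)  # [[1.5, 2.0]]
```

With this filter in place, an all-zero batch yields an empty donor list, so the weight-stat fallback path can run instead of propagating 0.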

328-344: ⚠️ Potential issue | 🟠 Major

get_experts_list does not handle Qwen3.5 model type — will raise NotImplementedError at runtime.

is_moe and get_expert_linear_names correctly recognize Qwen3_5MoeSparseMoeBlock, but get_experts_list (lines 91–99) dispatches on model type strings extracted as type(model).__name__.lower(). A Qwen3.5 model class like Qwen3_5MoeForCausalLM produces "qwen3_5moeforcausallm", which is not matched in the function's checks. If quantization uses AWQ or NVFP4_SVDQUANT, the code at line 297 in unified_export_hf.py executes get_experts_list(module, model_type) for any MoE module detected by is_moe, triggering NotImplementedError at line 102.

Add "qwen3_5moeforcausallm" to the model type checks in get_experts_list.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/layer_utils.py` around lines 328 - 344,
get_experts_list currently dispatches on type(model).__name__.lower() and lacks
handling for Qwen3.5 class names, so add "qwen3_5moeforcausallm" to the
model-type checks inside get_experts_list to match the same Qwen3_5 detection
used by is_moe and get_expert_linear_names; update the conditional branches that
compare model_type (from type(model).__name__.lower()) to include
"qwen3_5moeforcausallm" so get_experts_list returns the correct expert list
instead of raising NotImplementedError (also verify any other qwen3_5 variants
present in that same dispatch and add them if missing).
🧹 Nitpick comments (1)
examples/llm_ptq/hf_ptq.py (1)

649-653: Avoid duplicate custom-file copy in the TensorRT-LLM path.

Now that the canonical copy happens at Line 653 (after tokenizer export), the earlier TensorRT-LLM copy at Line 616 becomes redundant and adds extra I/O.

♻️ Proposed cleanup
-            # Copy custom model files (Python files and JSON configs) for TensorRT-LLM export
-            copy_custom_model_files(args.pyt_ckpt_path, export_path, args.trust_remote_code)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/llm_ptq/hf_ptq.py` around lines 649 - 653, Remove the duplicate
custom-file copy in the TensorRT-LLM export path: there is an earlier call to
copy_custom_model_files that runs during the TensorRT-LLM branch which becomes
redundant because the canonical copy occurs later via
copy_custom_model_files(args.pyt_ckpt_path, export_path, args.trust_remote_code)
after tokenizer.save_pretrained(); delete the earlier call (and any now-unused
surrounding conditional) so custom Python/JSON files are only copied once.

ℹ️ Review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ca1f968 and 4219853.

📒 Files selected for processing (5)
  • examples/llm_ptq/hf_ptq.py
  • modelopt/torch/export/layer_utils.py
  • modelopt/torch/export/unified_export_hf.py
  • modelopt/torch/quantization/plugins/huggingface.py
  • modelopt/torch/utils/dataset_utils.py

Comment on lines +1060 to 1074
        # Temporarily disable revert_weight_conversion if available — it doesn't handle
        # quantized state dicts (scalar scale tensors have 0 dimensions, causing IndexError).
        # We must patch both the source module and the importing module since
        # modeling_utils does `from core_model_loading import revert_weight_conversion`.
        _patches = _patch_revert_weight_conversion()

        try:
            model.save_pretrained(
                export_dir,
                state_dict={**post_state_dict, **(extra_state_dict or {})},
                save_modelopt_state=save_modelopt_state,
            )
        finally:
            _unpatch_revert_weight_conversion(_patches)

⚠️ Potential issue | 🟠 Major

Serialize global patching to avoid cross-export races.

The patch/unpatch sequence mutates module globals process-wide. Concurrent exports can interleave and restore the wrong function, causing flaky behavior.

🔒 Proposed fix (serialize patch window)
+import threading
...
+_REVERT_WEIGHT_CONVERSION_PATCH_LOCK = threading.Lock()
...
-        _patches = _patch_revert_weight_conversion()
-
-        try:
-            model.save_pretrained(
-                export_dir,
-                state_dict={**post_state_dict, **(extra_state_dict or {})},
-                save_modelopt_state=save_modelopt_state,
-            )
-        finally:
-            _unpatch_revert_weight_conversion(_patches)
+        with _REVERT_WEIGHT_CONVERSION_PATCH_LOCK:
+            _patches = _patch_revert_weight_conversion()
+            try:
+                model.save_pretrained(
+                    export_dir,
+                    state_dict={**post_state_dict, **(extra_state_dict or {})},
+                    save_modelopt_state=save_modelopt_state,
+                )
+            finally:
+                _unpatch_revert_weight_conversion(_patches)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/unified_export_hf.py` around lines 1060 - 1074, The
patch/unpatch sequence around
_patch_revert_weight_conversion/_unpatch_revert_weight_conversion mutates
process globals and must be serialized to avoid cross-export races; add a
module-level lock (e.g., threading.RLock) and acquire it before calling
_patch_revert_weight_conversion, keep it held across model.save_pretrained(...)
and the finally block, then release after _unpatch_revert_weight_conversion so
only one export at a time can patch globals; update any helper initialization to
use the new lock and ensure exceptions still trigger unpatch+release.

Comment on lines +913 to +919
if any(getattr(m, "_if_calib", False) for m in self.experts.modules()):
# Force all tokens to all experts during calibration
original_top_k = self.gate.top_k
self.gate.top_k = self.num_experts
super(_QuantSparseMoe, self).forward(hidden_states)
self.gate.top_k = original_top_k
return super(_QuantSparseMoe, self).forward(hidden_states)
⚠️ Potential issue | 🟠 Major

Always restore gate.top_k with finally.

If the calibration warmup forward throws, self.gate.top_k remains at self.num_experts, which can corrupt later routing behavior.

🛠️ Proposed fix
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if any(getattr(m, "_if_calib", False) for m in self.experts.modules()):
             # Force all tokens to all experts during calibration
             original_top_k = self.gate.top_k
             self.gate.top_k = self.num_experts
-            super(_QuantSparseMoe, self).forward(hidden_states)
-            self.gate.top_k = original_top_k
+            try:
+                super(_QuantSparseMoe, self).forward(hidden_states)
+            finally:
+                self.gate.top_k = original_top_k
         return super(_QuantSparseMoe, self).forward(hidden_states)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/plugins/huggingface.py` around lines 913 - 919,
The calibration block in _QuantSparseMoe.forward temporarily sets
self.gate.top_k to self.num_experts but restores it only after calling
super(...).forward, so an exception during calibration leaves top_k mutated;
wrap the calibration call in a try/finally: save original_top_k, set
self.gate.top_k = self.num_experts, call super(_QuantSparseMoe,
self).forward(hidden_states) inside try, and in finally always restore
self.gate.top_k = original_top_k to guarantee restoration even on errors.

Copilot AI left a comment
Pull request overview

This PR adds ModelOpt PTQ (Post-Training Quantization) support for the Qwen3.5-MoE model (Qwen/Qwen3.5-397B-A17B). The changes include new quantization module support, improved error diagnostics, bug fixes for tokenizer handling, and improvements to the model export workflow.

Changes:

  • Added _QuantQwen35MoeExperts quantization module for Qwen3.5 MoE architecture
  • Improved MoE layer detection with pattern-based matching (auto-detect modules ending with "sparsemoeblock")
  • Enhanced error handling in model export with detailed module information
  • Fixed tokenizer encoding to use modern tokenizer() method instead of deprecated batch_encode_plus
  • Added patching for revert_weight_conversion to handle quantized state dicts with scalar tensors
  • Improved zero amax handling in expert quantizer calibration
  • Fixed file copy order to preserve original tokenizer files over regenerated ones

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 1 comment.

Show a summary per file
  • modelopt/torch/utils/dataset_utils.py — Updated the tokenizer call from batch_encode_plus to the modern tokenizer() method and improved a comment.
  • modelopt/torch/quantization/plugins/huggingface.py — Added _QuantQwen35MoeExperts and _Qwen35MoeExpertModule classes for Qwen3.5 MoE support, with registration.
  • modelopt/torch/export/unified_export_hf.py — Enhanced error messages with module details and added revert_weight_conversion patching to handle scalar tensors.
  • modelopt/torch/export/layer_utils.py — Simplified MoE detection with pattern matching, added Qwen3_5MoeSparseMoeBlock to expert linear names, improved amax handling for zero values, and simplified a conditional expression.
  • examples/llm_ptq/hf_ptq.py — Moved the copy_custom_model_files call after tokenizer.save_pretrained to preserve original tokenizer files.


        with torch.no_grad():
            module.weight.data = weight.detach().data.to(dtype=dtype, device=device)

        expert_dim = self.intermediate_dim
Copilot AI commented Feb 26, 2026
The _QuantQwen35MoeExperts class directly accesses self.intermediate_dim without fallback handling, unlike _QuantQwen3VLMoeTextExperts which checks for both intermediate_size and intermediate_dim attributes. If the Qwen3.5 model uses intermediate_size instead of intermediate_dim, this will cause an AttributeError. Consider adding the same fallback logic used in _QuantQwen3VLMoeTextExperts (lines 671-676).

Suggested change
-        expert_dim = self.intermediate_dim
+        # Support both `intermediate_size` and `intermediate_dim` depending on the model config.
+        if hasattr(self, "intermediate_size"):
+            expert_dim = self.intermediate_size
+        else:
+            expert_dim = self.intermediate_dim
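A compact alternative to the hasattr if/else suggested above uses getattr with a sentinel; the classes below are hypothetical stand-ins, not ModelOpt or transformers classes:

```python
# Attribute fallback: prefer `intermediate_size`, fall back to `intermediate_dim`.

class ExpertsA:
    intermediate_size = 256  # configs that expose intermediate_size

class ExpertsB:
    intermediate_dim = 128   # configs that expose intermediate_dim

def expert_dim(mod) -> int:
    dim = getattr(mod, "intermediate_size", None)
    return dim if dim is not None else mod.intermediate_dim

print(expert_dim(ExpertsA()), expert_dim(ExpertsB()))  # 256 128
```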

@Edwardf0t1 Edwardf0t1 force-pushed the zhiyu/qwen3p5-moe-support branch from 9aaca69 to ffb50e4 Compare February 26, 2026 16:47
@Edwardf0t1 Edwardf0t1 enabled auto-merge (squash) February 26, 2026 19:32
@Edwardf0t1 Edwardf0t1 force-pushed the zhiyu/qwen3p5-moe-support branch from ffb50e4 to aa60527 Compare February 26, 2026 20:59
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
@Edwardf0t1 Edwardf0t1 force-pushed the zhiyu/qwen3p5-moe-support branch from 6af1cd4 to c7bb291 Compare February 26, 2026 22:52
@Edwardf0t1 Edwardf0t1 merged commit a415667 into main Feb 27, 2026
37 checks passed
@Edwardf0t1 Edwardf0t1 deleted the zhiyu/qwen3p5-moe-support branch February 27, 2026 00:06